# Make sure pacman is available, installing it the first time it is missing.
if (!require("pacman")) {
  install.packages("pacman")
}
## Loading required package: pacman
# Load the packages used by the rest of the analysis through pacman.
pacman::p_load(tidyverse, rjags, tidybayes)
# Turn on knitr chunk caching for every chunk in the document.
knitr::opts_chunk$set(cache = TRUE)

Load data

# Serotypes flagged as belonging to each vaccine formulation; the vectors are
# used further down for membership tests via %in%.
pcv7_st <- c(
  "4", "6B", "9V", "14", "18C", "19F", "23F"
)
pcv13_st <- c(
  "1", "3", "4", "5", "6A", "6B", "7F",
  "9V", "14", "18C", "19A", "19F", "23F"
)
pcv15_st <- c(
  "1", "3", "4", "5", "6A", "6B", "7F", "9V",
  "14", "18C", "19A", "19F", "22F", "23F", "33F"
)
pcv20_st <- c(
  "1", "3", "4", "5", "6A", "6B", "7F", "8", "9V", "10A",
  "11A", "12F", "14", "15B", "18C", "19A", "19F", "22F", "23F", "33F"
)

# Read the raw counts from the RDS file.
data_raw <- read_rds("data/ABCs_st_1998_2021.rds")

# Build one row per year x serotype:
#   1. give the raw columns short names and relabel serotype "16" as "16F";
#   2. sum the counts within each serotype-year (only st, year and N_IPD
#      survive the summary);
#   3. fill serotype-year combinations that are absent with a count of 0;
#   4. add 0/1 membership flags for each vaccine serotype list and `firstvax`,
#      the first of pcv7 / pcv13 / pcv15 / pcv20 whose flag is 1 ("" if none).
data <- data_raw |>
  rename(
    agec = Age.Group..years.,
    year = Year,
    st = IPD.Serotype,
    N_IPD = Frequency.Count
  ) |>
  mutate(st = if_else(st == "16", "16F", st)) |>
  group_by(st, year) |>
  summarise(N_IPD = sum(N_IPD)) |>
  ungroup() |>
  complete(year, st, fill = list(N_IPD = 0)) |>
  mutate(
    pcv7 = as.numeric(st %in% pcv7_st),
    pcv13 = as.numeric(st %in% pcv13_st),
    pcv15 = as.numeric(st %in% pcv15_st),
    pcv20 = as.numeric(st %in% pcv20_st),
    firstvax = case_when(
      pcv7 == 1 ~ "pcv7",
      pcv13 == 1 ~ "pcv13",
      pcv15 == 1 ~ "pcv15",
      pcv20 == 1 ~ "pcv20",
      .default = ""
    )
  )
## `summarise()` has grouped output by 'st'. You can override using the `.groups`
## argument.

Mixture model with PCV-13 – uninformative delta

Wrangle data

# Long-format modelling table: the serotypes listed in pcv13_st only, ordered
# by year, with year_n numbering each serotype's rows 1, 2, ... in that order.
mod9_df <- data |>
  complete(st, year, fill = list(N_IPD = 0)) |>
  arrange(st, year) |>
  filter(st %in% pcv13_st) |>
  arrange(year) |>
  mutate(year_n = row_number(), .by = st)

# Wide layout for JAGS: one row per year, one column per serotype, cells hold
# N_IPD; the year column itself is dropped.
mod9_mat <- mod9_df |>
  reshape2::dcast(formula = year ~ st, value.var = "N_IPD") |>
  dplyr::select(!year)

Define model string

# JAGS model code, held as a character string, for the two-group mixture model.
# (In JAGS, the second argument of dnorm is a precision, not a variance.)
#   - Likelihood: N_IPD[i,j] ~ dnegbin(prob[i,j], r[j]) with
#     prob[i,j] = r[j] / (r[j] + lambda[i,j]) and log(lambda[i,j]) = epsilon1[i,j].
#   - beta1, beta2: two AR(1) series over years with intercepts alpha1 / alpha2,
#     correlations rho_beta1 / rho_beta2 and precisions tau_beta1 / tau_beta2.
#   - epsilon1[, j]: AR(1) series per serotype (shared rho_eps and tau_eps)
#     whose mean adds delta1[j] and either beta1[i] (grp[j] = 1) or beta2[i]
#     (grp[j] = 0).
#   - grp[j] ~ dbern(pi[j]), where pi[j] is the inverse logit of
#     logit_pi[j] ~ dnorm(0, 1e-4), i.e. one membership probability per serotype.
#   - delta1[j] ~ dnorm(0, 1e-4) and r[j] ~ dunif(0, 250) for every serotype.
# NOTE(review): tau_global receives a dgamma(0.01, 0.01) prior below but no
# other node in this string refers to it.
mod9_string <- "
  model {
    
    for(i in 1:N_years){
      
      for(j in 1:N_sts){
        N_IPD[i,j] ~ dnegbin(prob[i,j], r[j])
        
        prob[i,j] <- r[j] / (r[j] + lambda[i,j])  # Likelihood 
        
        log(lambda[i,j]) <- epsilon1[i,j]  # Serotype-specific intercept + AR(1) effect
      }
    }
    
    # Global AR(1) effect for beta1
    
    beta1[1] ~ dnorm(alpha1, (1 - rho_beta1^2) * tau_beta1)
    for(i in 2:N_years){
      beta1[i] ~ dnorm(alpha1 + rho_beta1 * beta1[i-1], tau_beta1)
    }
    
    # Global AR(1) effect for beta2
    
    beta2[1] ~ dnorm(alpha2, (1 - rho_beta2^2) * tau_beta2)
    for(i in 2:N_years){
      beta2[i] ~ dnorm(alpha2 + rho_beta2 * beta2[i-1], tau_beta2)
    }
        
    # Serotype-specific AR(1) effect with group selection
    for(j in 1:N_sts){
      epsilon1[1,j] ~ dnorm(delta1[j] + beta1[1] * grp[j] + beta2[1] * (1 - grp[j]), 
                            (1 - rho_eps^2) * tau_eps)
      
      for(i in 2:N_years){
        epsilon1[i,j] ~ dnorm(delta1[j] + beta1[i] * grp[j] + beta2[i] * (1 - grp[j]) + 
                              rho_eps * epsilon1[i-1, j], tau_eps)
      }
    }
    
    # Priors for group selection
    for(j in 1:N_sts){
      logit_pi[j] ~ dnorm(0, 1e-4)  # Prior on logit(pi)
      pi[j] <- exp(logit_pi[j]) / (exp(logit_pi[j]) + 1)  # Inverse logit
      grp[j] ~ dbern(pi[j])  # Group assignment - 0 or 1
    }
        
    # Priors
    alpha1 ~ dnorm(0, 1e-4)
    alpha2 ~ dnorm(0, 1e-4)
    tau_global ~ dgamma(0.01, 0.01)

    rho_beta1 ~ dunif(-1, 1)   # Uniform prior for rho_beta1 -- global AR(1) for beta1
    rho_beta2 ~ dunif(-1, 1)   # Uniform prior for rho_beta2 -- global AR(1) for beta2
    rho_eps ~ dunif(-1, 1)     # Prior for rho_eps same for all STs
    
    tau_beta1 ~ dgamma(3, 2)   # Tight prior for tau for beta1
    tau_beta2 ~ dgamma(3, 2)   # Tight prior for tau for beta2
    tau_eps ~ dgamma(3, 2)     # Tight prior for tau, shared for all serotypes
    
    for(j in 1:N_sts){
      delta1[j] ~ dnorm(0, 1e-4)  # Serotype means uninformative
      r[j] ~ dunif(0, 250)  # Serotype dispersion parameter
    }
  }
"

Specify initial values

# Per-chain RNG settings: every chain uses the Wichmann-Hill generator with its
# own seed. No starting values for model parameters are supplied here.
inits1 <- list(.RNG.seed = 123, .RNG.name = "base::Wichmann-Hill")
inits2 <- list(.RNG.seed = 456, .RNG.name = "base::Wichmann-Hill")
inits3 <- list(.RNG.seed = 789, .RNG.name = "base::Wichmann-Hill")

Create JAGS model object

# Compile mod9_string and run the adaptation phase (3 chains, 10,000 adaptation
# iterations, one RNG seed list per chain).
# The model text is handed to jags.model() through a text connection that is
# closed explicitly afterwards, so no connection is left open in the session.
# The count table is passed as an explicit matrix (years in rows, serotypes in
# columns) rather than as a data frame.
mod9_con <- textConnection(mod9_string)
mod9 <- jags.model(mod9_con,
                   data = list("N_IPD" = as.matrix(mod9_mat),
                               "N_years" = max(mod9_df$year_n),
                               "N_sts" = ncol(mod9_mat)),
                   inits = list(inits1, inits2, inits3),
                   n.adapt = 10000,
                   n.chains = 3)
close(mod9_con)

Sample posteriors

# Names of the model nodes to monitor while sampling mod9.
params9 <- c(
  "alpha1", "epsilon1", "delta1", "beta1",
  "beta2", "tau_beta1", "lambda", "grp"
)

# Draw 50,000 iterations from the adapted model for the monitored nodes.
mod9_postsamp <- coda.samples(
  model = mod9,
  variable.names = params9,
  n.iter = 50000
)

# save(mod9_postsamp, file = "data/interim/mod9_postsamp.rda")
# load() restores whatever objects the .rda file holds into the workspace.
# NOTE(review): presumably this replaces the freshly sampled mod9_postsamp with
# the stored copy written by the commented-out save() above -- confirm.
load(file = "data/interim/mod9_postsamp.rda")

Plot traces for beta 1 and beta 2

# Monitored node names that start with "beta" (beta1[...] and beta2[...]).
beta_vars9 <- grep(
  pattern = "^beta",
  x = varnames(mod9_postsamp),
  value = TRUE
)
# For each of them, print a "#### <name>" line and then draw its trace plot.
for (i in seq_along(beta_vars9)) {
  cat("\n#### ", beta_vars9[i], " \n", "\n", sep = "")
  traceplot(mod9_postsamp[, beta_vars9[i]], main = beta_vars9[i])
  cat("\n")
}

beta1[1]

beta1[2]

beta1[3]

beta1[4]

beta1[5]

beta1[6]

beta1[7]

beta1[8]

beta1[9]

beta1[10]

beta1[11]

beta1[12]

beta1[13]

beta1[14]

beta1[15]

beta1[16]

beta1[17]

beta1[18]

beta1[19]

beta1[20]

beta1[21]

beta1[22]

beta1[23]

beta1[24]

beta1[25]

beta2[1]

beta2[2]

beta2[3]

beta2[4]

beta2[5]

beta2[6]

beta2[7]

beta2[8]

beta2[9]

beta2[10]

beta2[11]

beta2[12]

beta2[13]

beta2[14]

beta2[15]

beta2[16]

beta2[17]

beta2[18]

beta2[19]

beta2[20]

beta2[21]

beta2[22]

beta2[23]

beta2[24]

beta2[25]

Show and plot AR1 groups

# Lookup tables from the integer indices used inside the JAGS model back to
# labels. Names are the index (as text); values are the label.
#   st_mapping9:   column position in mod9_mat -> serotype.
#   year_mapping9: year position -> calendar year.
# The index range is taken from the object being labelled rather than from
# pcv13_st, so the lookup stays valid when a listed serotype has no column in
# mod9_mat. Years are sorted explicitly so the lookup does not depend on the
# row order of mod9_df.
st_mapping9 <- setNames(colnames(mod9_mat), seq_along(colnames(mod9_mat)))
year_mapping9 <- setNames(sort(unique(mod9_df$year)),
                          seq_along(unique(mod9_df$year)))

# Summarise the grp[j] draws per serotype (median and HDI bounds), translate the
# index j to a serotype label, and label each serotype by whether the median
# exceeds 0.5.
groups9 <- mod9_postsamp |>
  gather_draws(grp[j]) |>
  median_hdi() |>
  rename(st = j) |>
  mutate(
    st = st_mapping9[as.character(st)],
    group = if_else(.value > 0.5, "Group 1 (beta1)", "Group 2 (beta 2)")
  )

# Serotypes assigned to each group, as one comma-separated string per group.
groups9 |>
  group_by(group) |>
  summarise(serotypes = toString(st))

# Median and interval of grp for every serotype.
ggplot(groups9, aes(x = st, y = .value, ymin = .lower, ymax = .upper)) +
  geom_pointrange() +
  theme_bw()

Plot trajectory of global AR(1)s

AR1 - beta1

# Median and HDI bounds of beta1 for each year index, with the index mapped
# back to the calendar year.
mod9_beta1 <- mod9_postsamp |>
  gather_draws(beta1[i]) |>
  median_hdi() |>
  mutate(year = year_mapping9[as.character(i)])

# Trajectory of beta1 over time: median as a line, interval as a ribbon.
ggplot(mod9_beta1, aes(x = year, y = .value)) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), alpha = 0.2) +
  geom_line() +
  labs(x = "", y = "Global AR(1) - Beta 1") +
  theme_minimal()

AR1 - beta2

# Median and HDI bounds of beta2 for each year index, with the index mapped
# back to the calendar year.
mod9_beta2 <- mod9_postsamp |>
  gather_draws(beta2[i]) |>
  median_hdi() |>
  mutate(year = year_mapping9[as.character(i)])

# Trajectory of beta2 over time: median as a line, interval as a ribbon.
ggplot(mod9_beta2, aes(x = year, y = .value)) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), alpha = 0.2) +
  geom_line() +
  labs(x = "", y = "Global AR(1) - Beta 2") +
  theme_minimal()

Plot obs vs pred

# Median and HDI bounds of lambda[i, j], relabelled from (year index, serotype
# index) to (calendar year, serotype) and joined to the rows of mod9_df.
mod9_summary <- mod9_postsamp |>
  gather_draws(lambda[i, j]) |>
  median_hdi() |>
  rename(year = i, st = j) |>
  mutate(
    st = st_mapping9[as.character(st)],
    year = year_mapping9[as.character(year)]
  ) |>
  left_join(mod9_df, by = join_by(year, st))

Standard

# Observed counts (points) against the median of lambda (line) and its interval
# (ribbon); one panel per serotype, all panels sharing the same y-axis.
ggplot(mod9_summary, aes(x = year, y = .value)) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), alpha = 0.2) +
  geom_point(aes(y = N_IPD, color = "Observed")) +
  geom_line(aes(color = "Fitted")) +
  facet_wrap(vars(st)) +
  labs(x = "", y = "# IPD", color = "") +
  theme_minimal()

Log-transformed y-axis

# Same observed-vs-fitted panels, drawn with a log1p-transformed y coordinate.
ggplot(mod9_summary, aes(x = year, y = .value)) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), alpha = 0.2) +
  geom_point(aes(y = N_IPD, color = "Observed")) +
  geom_line(aes(color = "Fitted")) +
  facet_wrap(vars(st)) +
  labs(x = "", y = "# IPD", color = "") +
  theme_minimal() +
  coord_trans(y = "log1p")

Free y-axis

# Same observed-vs-fitted panels, with an independent y-axis for each serotype.
ggplot(mod9_summary, aes(x = year, y = .value)) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), alpha = 0.2) +
  geom_point(aes(y = N_IPD, color = "Observed")) +
  geom_line(aes(color = "Fitted")) +
  facet_wrap(vars(st), scales = "free_y") +
  labs(x = "", y = "# IPD", color = "") +
  theme_minimal()

Mixture model with PCV-13 – delta centered around 0

Define model string

# JAGS model code, held as a character string, for the second mixture model.
# (In JAGS, the second argument of dnorm is a precision, not a variance.)
#   - Likelihood: N_IPD[i,j] ~ dnegbin(prob[i,j], r[j]) with
#     prob[i,j] = r[j] / (r[j] + lambda[i,j]) and log(lambda[i,j]) = epsilon1[i,j].
#   - beta1, beta2: two AR(1) series over years with intercepts alpha1 / alpha2,
#     correlations rho_beta1 / rho_beta2 and precisions tau_beta1 / tau_beta2.
#   - epsilon1[, j]: AR(1) series per serotype (shared rho_eps and tau_eps)
#     whose mean adds delta1[j] and either beta1[i] (grp[j] = 1) or beta2[i]
#     (grp[j] = 0).
#   - grp[j] ~ dbern(pi) with a single pi ~ dunif(0,1) shared by all serotypes.
#   - delta1[j] ~ dnorm(0, tau_global), with tau_global ~ dgamma(0.01, 0.01),
#     so the serotype means share one estimated precision around 0.
#   - r[j] ~ dunif(0, 250) for every serotype.
mod10_string <- "
  model {
    
    for(i in 1:N_years){
      
      for(j in 1:N_sts){
        N_IPD[i,j] ~ dnegbin(prob[i,j], r[j])
        
        prob[i,j] <- r[j] / (r[j] + lambda[i,j])  # Likelihood 
        
        log(lambda[i,j]) <- epsilon1[i,j]  # Serotype-specific intercept + AR(1) effect
      }
    }
    
    # Global AR(1) effect for beta1
    
    beta1[1] ~ dnorm(alpha1, (1 - rho_beta1^2) * tau_beta1)
    for(i in 2:N_years){
      beta1[i] ~ dnorm(alpha1 + rho_beta1 * beta1[i-1], tau_beta1)
    }
    
    # Global AR(1) effect for beta2
    
    beta2[1] ~ dnorm(alpha2, (1 - rho_beta2^2) * tau_beta2)
    for(i in 2:N_years){
      beta2[i] ~ dnorm(alpha2 + rho_beta2 * beta2[i-1], tau_beta2)
    }
        
    # Serotype-specific AR(1) effect with group selection
    for(j in 1:N_sts){
      epsilon1[1,j] ~ dnorm(delta1[j] + beta1[1] * grp[j] + beta2[1] * (1 - grp[j]), 
                            (1 - rho_eps^2) * tau_eps)
      
      for(i in 2:N_years){
        epsilon1[i,j] ~ dnorm(delta1[j] + beta1[i] * grp[j] + beta2[i] * (1 - grp[j]) + 
                              rho_eps * epsilon1[i-1, j], tau_eps)
      }
    }
    
    # Priors for group selection
    for(j in 1:N_sts){
      grp[j] ~ dbern(pi)  # Group assignment - 0 or 1
    }
    
    pi ~ dunif(0,1)
        
    # Priors
    alpha1 ~ dnorm(0, 1e-4)
    alpha2 ~ dnorm(0, 1e-4)
    tau_global ~ dgamma(0.01, 0.01)

    rho_beta1 ~ dunif(-1, 1)   # Uniform prior for rho_beta1 -- global AR(1) for beta1
    rho_beta2 ~ dunif(-1, 1)   # Uniform prior for rho_beta2 -- global AR(1) for beta2
    rho_eps ~ dunif(-1, 1)     # Prior for rho_eps same for all STs
    
    tau_beta1 ~ dgamma(3, 2)   # Tight prior for tau for beta1
    tau_beta2 ~ dgamma(3, 2)   # Tight prior for tau for beta2
    tau_eps ~ dgamma(3, 2)     # Tight prior for tau, shared for all serotypes
    
    for(j in 1:N_sts){
      delta1[j] ~ dnorm(0, tau_global)  # Serotype means centered around 0
      r[j] ~ dunif(0, 250)  # Serotype dispersion parameter
    }
  }
"

Create JAGS model object

# Compile mod10_string and run the adaptation phase (3 chains, 10,000
# adaptation iterations, one RNG seed list per chain). The data are the same
# count table and dimensions used for mod9.
# The model text is handed to jags.model() through a text connection that is
# closed explicitly afterwards, so no connection is left open in the session.
# The count table is passed as an explicit matrix (years in rows, serotypes in
# columns) rather than as a data frame.
mod10_con <- textConnection(mod10_string)
mod10 <- jags.model(mod10_con,
                    data = list("N_IPD" = as.matrix(mod9_mat),
                                "N_years" = max(mod9_df$year_n),
                                "N_sts" = ncol(mod9_mat)),
                    inits = list(inits1, inits2, inits3),
                    n.adapt = 10000,
                    n.chains = 3)
close(mod10_con)

Sample posteriors

# Names of the model nodes to monitor while sampling mod10.
params10 <- c(
  "alpha1", "alpha2", "epsilon1", "delta1", "beta1", "beta2",
  "tau_beta1", "tau_beta2", "lambda",
  "rho_beta1", "rho_beta2", "rho_eps",
  "pi", "grp", "r"
)

# Draw 10,000 iterations from the adapted model for the monitored nodes.
mod10_postsamp <- coda.samples(
  model = mod10,
  variable.names = params10,
  n.iter = 10000
)

# save(mod10_postsamp, file = "data/interim/mod10_postsamp.rda")
# load() restores whatever objects the .rda file holds into the workspace.
# NOTE(review): presumably this replaces the freshly sampled mod10_postsamp
# with the stored copy written by the commented-out save() above -- confirm.
load(file = "data/interim/mod10_postsamp.rda")

Plot traces for beta 1 and beta 2

# Monitored node names that start with "beta" (beta1[...] and beta2[...]).
beta_vars10 <- grep(
  pattern = "^beta",
  x = varnames(mod10_postsamp),
  value = TRUE
)
# For each of them, print a "#### <name>" line and then draw its trace plot.
for (i in seq_along(beta_vars10)) {
  cat("\n#### ", beta_vars10[i], " \n", "\n", sep = "")
  traceplot(mod10_postsamp[, beta_vars10[i]], main = beta_vars10[i])
  cat("\n")
}

beta1[1]

beta1[2]

beta1[3]

beta1[4]

beta1[5]

beta1[6]

beta1[7]

beta1[8]

beta1[9]

beta1[10]

beta1[11]

beta1[12]

beta1[13]

beta1[14]

beta1[15]

beta1[16]

beta1[17]

beta1[18]

beta1[19]

beta1[20]

beta1[21]

beta1[22]

beta1[23]

beta1[24]

beta1[25]

beta2[1]

beta2[2]

beta2[3]

beta2[4]

beta2[5]

beta2[6]

beta2[7]

beta2[8]

beta2[9]

beta2[10]

beta2[11]

beta2[12]

beta2[13]

beta2[14]

beta2[15]

beta2[16]

beta2[17]

beta2[18]

beta2[19]

beta2[20]

beta2[21]

beta2[22]

beta2[23]

beta2[24]

beta2[25]

Show and plot AR1 groups

# Summarise the grp[j] draws per serotype (median and HDI bounds), translate the
# index j to a serotype label with the mod9 lookup, and label each serotype by
# whether the median exceeds 0.5.
groups10 <- mod10_postsamp |>
  gather_draws(grp[j]) |>
  median_hdi() |>
  rename(st = j) |>
  mutate(
    st = st_mapping9[as.character(st)],
    group = if_else(.value > 0.5, "Group 1 (beta1)", "Group 2 (beta 2)")
  )

# Serotypes assigned to each group, as one comma-separated string per group.
groups10 |>
  group_by(group) |>
  summarise(serotypes = toString(st))

# Median and interval of grp for every serotype.
ggplot(groups10, aes(x = st, y = .value, ymin = .lower, ymax = .upper)) +
  geom_pointrange() +
  theme_bw()

Plot trajectory of global AR(1)s

AR1 - beta1

# Median and HDI bounds of beta1 for each year index, with the index mapped
# back to the calendar year.
mod10_beta1 <- mod10_postsamp |>
  gather_draws(beta1[i]) |>
  median_hdi() |>
  mutate(year = year_mapping9[as.character(i)])

# Trajectory of beta1 over time: median as a line, interval as a ribbon.
ggplot(mod10_beta1, aes(x = year, y = .value)) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), alpha = 0.2) +
  geom_line() +
  labs(x = "", y = "Global AR(1) - Beta 1") +
  theme_minimal()

AR1 - beta2

# Median and HDI bounds of beta2 for each year index, with the index mapped
# back to the calendar year.
mod10_beta2 <- mod10_postsamp |>
  gather_draws(beta2[i]) |>
  median_hdi() |>
  mutate(year = year_mapping9[as.character(i)])

# Trajectory of beta2 over time: median as a line, interval as a ribbon.
ggplot(mod10_beta2, aes(x = year, y = .value)) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), alpha = 0.2) +
  geom_line() +
  labs(x = "", y = "Global AR(1) - Beta 2") +
  theme_minimal()

Plot obs vs pred

# Median and HDI bounds of lambda[i, j], relabelled from (year index, serotype
# index) to (calendar year, serotype) and joined to the rows of mod9_df.
mod10_summary <- mod10_postsamp |>
  gather_draws(lambda[i, j]) |>
  median_hdi() |>
  rename(year = i, st = j) |>
  mutate(
    st = st_mapping9[as.character(st)],
    year = year_mapping9[as.character(year)]
  ) |>
  left_join(mod9_df, by = join_by(year, st))

Standard

# Observed counts (points) against the median of lambda (line) and its interval
# (ribbon); one panel per serotype, all panels sharing the same y-axis.
ggplot(mod10_summary, aes(x = year, y = .value)) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), alpha = 0.2) +
  geom_point(aes(y = N_IPD, color = "Observed")) +
  geom_line(aes(color = "Fitted")) +
  facet_wrap(vars(st)) +
  labs(x = "", y = "# IPD", color = "") +
  theme_minimal()

Log-transformed y-axis

# Same observed-vs-fitted panels, drawn with a log1p-transformed y coordinate.
ggplot(mod10_summary, aes(x = year, y = .value)) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), alpha = 0.2) +
  geom_point(aes(y = N_IPD, color = "Observed")) +
  geom_line(aes(color = "Fitted")) +
  facet_wrap(vars(st)) +
  labs(x = "", y = "# IPD", color = "") +
  theme_minimal() +
  coord_trans(y = "log1p")

Free y-axis

# Same observed-vs-fitted panels, with an independent y-axis for each serotype.
ggplot(mod10_summary, aes(x = year, y = .value)) +
  geom_ribbon(aes(ymin = .lower, ymax = .upper), alpha = 0.2) +
  geom_point(aes(y = N_IPD, color = "Observed")) +
  geom_line(aes(color = "Fitted")) +
  facet_wrap(vars(st), scales = "free_y") +
  labs(x = "", y = "# IPD", color = "") +
  theme_minimal()